{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "human-moral",
   "metadata": {},
   "outputs": [],
   "source": [
    "#导入相关依赖库\n",
    "import  os\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import mindspore as ms\n",
    "#context模块用于设置实验环境和实验设备\n",
    "import mindspore.context as context\n",
    "#dataset模块用于处理数据形成数据集\n",
    "import mindspore.dataset as ds\n",
    "#c_transforms模块用于转换数据类型\n",
    "import mindspore.dataset.transforms.c_transforms as C\n",
    "#vision.c_transforms模块用于转换图像，这是一个基于opencv的高级API\n",
    "import mindspore.dataset.vision.c_transforms as CV\n",
    "#导入Accuracy作为评价指标\n",
    "from mindspore.nn.metrics import Accuracy\n",
    "#nn中有各种神经网络层如：Dense，ReLu\n",
    "from mindspore import nn\n",
    "#Model用于创建模型对象，完成网络搭建和编译，并用于训练和评估\n",
    "from mindspore.train import Model\n",
    "#LossMonitor可以在训练过程中返回LOSS值作为监控指标\n",
    "from mindspore.train.callback import  LossMonitor\n",
    "import importlib\n",
    "import MyOptimizer\n",
    "# from MyOptimizer import GdOptimizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "radical-investigator",
   "metadata": {},
   "outputs": [],
   "source": [
    "#MindSpore内置方法读取MNIST数据集\n",
    "ds_train = ds.MnistDataset(os.path.join(r'./MNIST', \"train\"))\n",
    "ds_test = ds.MnistDataset(os.path.join(r'./MNIST', \"test\")) \n",
    "\n",
    "print('训练数据集数量：',ds_train.get_dataset_size())\n",
    "print('测试数据集数量：',ds_test.get_dataset_size())\n",
    "#该数据集可以通过create_dict_iterator()转换为迭代器形式，然后通过__next__()一个个输出样本\n",
    "image=ds_train.create_dict_iterator().__next__()\n",
    "print(type(image))\n",
    "print('图像长/宽/通道数：',image['image'].shape)\n",
    "#一共10类，用0-9的数字表达类别。\n",
    "print('一张图像的标签样式：',image['label'])  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "southwest-movie",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR_TRAIN = \"./MNIST/train\" # 训练集信息\n",
    "DATA_DIR_TEST = \"./MNIST/test\" # 测试集信息\n",
    "\n",
    "def create_dataset(training=True, batch_size=128, resize=(28, 28),rescale=1/255, shift=-0.5, buffer_size=64):\n",
    "    ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)\n",
    "    \n",
    "    #定义改变形状、归一化和更改图片维度的操作。\n",
    "    #改为（28,28）的形状\n",
    "    resize_op = CV.Resize(resize)\n",
    "    #rescale方法可以对数据集进行归一化和标准化操作，这里就是将像素值归一到0和1之间，shift参数可以让值域偏移至-0.5和0.5之间\n",
    "    rescale_op = CV.Rescale(rescale, shift)\n",
    "    #由高度、宽度、深度改为深度、高度、宽度\n",
    "    hwc2chw_op = CV.HWC2CHW()\n",
    "    \n",
    "    # 利用map操作对原数据集进行调整\n",
    "    ds = ds.map(input_columns=\"image\", operations=[resize_op, rescale_op, hwc2chw_op])\n",
    "    ds = ds.map(input_columns=\"label\", operations=C.TypeCast(ms.int32))\n",
    "    #设定洗牌缓冲区的大小，从一定程度上控制打乱操作的混乱程度\n",
    "    ds = ds.shuffle(buffer_size=buffer_size)\n",
    "    #设定数据集的batch_size大小，并丢弃剩余的样本\n",
    "    ds = ds.batch(batch_size, drop_remainder=True)\n",
    "    \n",
    "    return ds\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "considered-hampshire",
   "metadata": {},
   "outputs": [],
   "source": [
    "#显示前10张图片以及对应标签,检查图片是否是正确的数据集\n",
    "dataset_show = create_dataset(training=False)\n",
    "data = dataset_show.create_dict_iterator().__next__()\n",
    "images = data['image'].asnumpy()\n",
    "labels = data['label'].asnumpy()\n",
    "\n",
    "for i in range(1,11):\n",
    "    plt.subplot(2, 5, i)\n",
    "    #利用squeeze方法去掉多余的一个维度\n",
    "    plt.imshow(np.squeeze(images[i]))\n",
    "    plt.title('Number: %s' % labels[i])\n",
    "    plt.xticks([])\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "polyphonic-stewart",
   "metadata": {},
   "outputs": [],
   "source": [
    "#利用定义类的方式生成网络，Mindspore中定义网络需要继承nn.cell。在init方法中定义该网络需要的神经网络层\n",
    "#在construct方法中梳理神经网络层与层之间的关系。\n",
    "class ForwardNN(nn.Cell):      \n",
    "    def __init__(self):\n",
    "        super(ForwardNN, self).__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.fc1 = nn.Dense(784, 512, activation='relu')\n",
    "        self.fc2 = nn.Dense(512, 256, activation='relu')\n",
    "        self.fc3 = nn.Dense(256, 128, activation='relu')\n",
    "        self.fc4 = nn.Dense(128, 64, activation='relu')\n",
    "        self.fc5 = nn.Dense(64, 32, activation='relu')\n",
    "        self.fc6 = nn.Dense(32, 10, activation='softmax')\n",
    "    \n",
    "    def construct(self, input_x):\n",
    "        output = self.flatten(input_x)\n",
    "        output = self.fc1(output)\n",
    "        output = self.fc2(output)\n",
    "        output = self.fc3(output)   \n",
    "        output = self.fc4(output)\n",
    "        output = self.fc5(output)\n",
    "        output = self.fc6(output)\n",
    "        return output \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "important-consequence",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 0.1\n",
    "num_epoch = 10\n",
    "momentum = 0.9\n",
    "\n",
    "net = ForwardNN()\n",
    "#定义loss函数，改函数不需要求导，可以给离散的标签值，且loss值为均值\n",
    "loss = nn.loss.SoftmaxCrossEntropyWithLogits( sparse=True, reduction='mean')\n",
    "#定义准确率为评价指标，用于评价模型\n",
    "metrics={\"Accuracy\": Accuracy()}\n",
    "#定义优化器为Adam优化器，并设定学习率\n",
    "# opt = nn.Adam(net.trainable_params(), lr)\n",
    "importlib.reload(MyOptimizer)\n",
    "opt = MyOptimizer.GdOptimizer(net.trainable_params(), lr)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "understood-tracker",
   "metadata": {},
   "outputs": [],
   "source": [
    "#生成验证集，验证机不需要训练，所以不需要repeat\n",
    "ds_eval = create_dataset(False, batch_size=32)\n",
    "#模型编译过程，将定义好的网络、loss函数、评价指标、优化器编译\n",
    "model = Model(net, loss, opt, metrics)\n",
    "\n",
    "#生成训练集\n",
    "ds_train = create_dataset(True, batch_size=32)\n",
    "print(\"============== Starting Training ==============\")\n",
    "#训练模型，用loss作为监控指标，并利用昇腾芯片的数据下沉特性进行训练\n",
    "model.train(num_epoch, ds_train,callbacks=[LossMonitor()],dataset_sink_mode=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sexual-disabled",
   "metadata": {},
   "outputs": [],
   "source": [
    "#使用测试集评估模型，打印总体准确率\n",
    "metrics_result=model.eval(ds_eval)\n",
    "print(metrics_result)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "persistent-crest",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "distributedml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
